{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ensemble Design Pattern\n",
    "\n",
    "Stacking is an Ensemble method which combines the outputs of a collection of models to make a prediction. The initial models, which are typically of different model types, are trained to completion on the full training dataset. Then, a secondary meta-model is trained using the initial model outputs as features. This second meta-model learns how to best combine the outcomes of the initial models to decrease the training error and can be any type of machine learning model."
   ]
  },
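  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea can be sketched without a deep learning framework. The cell below is a minimal illustration on synthetic data, not the model built later in this notebook: the three base predictors are simulated as noisy copies of the target, and the meta-model is an ordinary least-squares fit (via NumPy) on their outputs. All names and numbers in it are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(seed=0)\n",
    "\n",
    "# Synthetic regression target (hypothetical data, for illustration only).\n",
    "y = rng.normal(loc=7.0, scale=1.3, size=1000)\n",
    "\n",
    "# Simulate three already-trained base models as noisy, biased views of y.\n",
    "base_preds = np.column_stack([\n",
    "    y + rng.normal(scale=0.9, size=y.size) + 0.5,\n",
    "    0.8 * y + rng.normal(scale=0.7, size=y.size),\n",
    "    y + rng.normal(scale=1.1, size=y.size) - 0.3,\n",
    "])\n",
    "\n",
    "# Meta-model: linear regression on the base outputs plus an intercept.\n",
    "meta_features = np.column_stack([base_preds, np.ones(y.size)])\n",
    "weights, *_ = np.linalg.lstsq(meta_features, y, rcond=None)\n",
    "stacked_pred = meta_features @ weights\n",
    "\n",
    "# Training MSE of each base model and of the stacked combination.\n",
    "base_mse = ((base_preds - y[:, None]) ** 2).mean(axis=0)\n",
    "stacked_mse = ((stacked_pred - y) ** 2).mean()\n",
    "\n",
    "print(\"base MSE:   \", base_mse)\n",
    "print(\"stacked MSE:\", stacked_mse)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because copying any single base model is one of the linear combinations available to this meta-model, its training error cannot exceed that of the best base model. Whether that advantage carries over to held-out data has to be checked separately."
   ]
  },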
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a Stacking Ensemble model\n",
    "\n",
    "In this notebook, we'll create an Ensemble of three neural network models and train on the natality dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "from tensorflow import feature_column as fc\n",
    "from tensorflow.keras import layers, models, Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>weight_pounds</th>\n",
       "      <th>is_male</th>\n",
       "      <th>mother_age</th>\n",
       "      <th>plurality</th>\n",
       "      <th>gestation_weeks</th>\n",
       "      <th>mother_race</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7.749249</td>\n",
       "      <td>False</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>40</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.561856</td>\n",
       "      <td>True</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>40</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.187070</td>\n",
       "      <td>False</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>34</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6.375769</td>\n",
       "      <td>True</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>36</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7.936641</td>\n",
       "      <td>False</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>35</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   weight_pounds  is_male  mother_age  plurality  gestation_weeks  mother_race\n",
       "0       7.749249    False          12  Single(1)               40          1.0\n",
       "1       7.561856     True          12  Single(1)               40          2.0\n",
       "2       7.187070    False          12  Single(1)               34          3.0\n",
       "3       6.375769     True          12  Single(1)               36          2.0\n",
       "4       7.936641    False          12  Single(1)               35          NaN"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(\"./data/babyweight_train.csv\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create our `tf.data` input pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Determine CSV, label, and key columns\n",
    "# Create list of string column headers, make sure order matches.\n",
    "CSV_COLUMNS = [\"weight_pounds\",\n",
    "               \"is_male\",\n",
    "               \"mother_age\",\n",
    "               \"plurality\",\n",
    "               \"gestation_weeks\",\n",
    "               \"mother_race\"]\n",
    "\n",
    "# Add string name for label column\n",
    "LABEL_COLUMN = \"weight_pounds\"\n",
    "\n",
    "# Set default values for each CSV column as a list of lists.\n",
    "# Treat is_male and plurality as strings.\n",
    "DEFAULTS = [[0.0], [\"null\"], [0.0], [\"null\"], [0.0], [\"0\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(file_path):\n",
    "    dataset = tf.data.experimental.make_csv_dataset(\n",
    "        file_path,\n",
    "        batch_size=15, # Artificially small to make examples easier to show.\n",
    "        label_name=LABEL_COLUMN,\n",
    "        select_columns=CSV_COLUMNS,\n",
    "        column_defaults=DEFAULTS,\n",
    "        num_epochs=1,\n",
    "        ignore_errors=True)\n",
    "    return dataset\n",
    "\n",
    "train_data = get_dataset(\"./data/babyweight_train.csv\")\n",
    "test_data = get_dataset(\"./data/babyweight_eval.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that our tf.data dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_male             : [b'True' b'True' b'True' b'False' b'True' b'True' b'False' b'True' b'True'\n",
      " b'False' b'True' b'True' b'True' b'True' b'False']\n",
      "mother_age          : [16. 17. 17. 18. 16. 17. 17. 18. 17. 17. 17. 17. 16. 16. 16.]\n",
      "plurality           : [b'Single(1)' b'Single(1)' b'Single(1)' b'Single(1)' b'Single(1)'\n",
      " b'Single(1)' b'Single(1)' b'Single(1)' b'Single(1)' b'Single(1)'\n",
      " b'Single(1)' b'Single(1)' b'Single(1)' b'Single(1)' b'Single(1)']\n",
      "gestation_weeks     : [38. 38. 40. 39. 41. 39. 40. 42. 40. 33. 33. 40. 39. 38. 39.]\n",
      "mother_race         : [b'2.0' b'2.0' b'2.0' b'2.0' b'2.0' b'0' b'1.0' b'1.0' b'0' b'2.0' b'2.0'\n",
      " b'1.0' b'1.0' b'2.0' b'2.0']\n"
     ]
    }
   ],
   "source": [
    "def show_batch(dataset):\n",
    "    for batch, label in dataset.take(1):\n",
    "        for key, value in batch.items():\n",
    "            print(\"{:20s}: {}\".format(key,value.numpy()))\n",
    "\n",
    "show_batch(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create our feature columns "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "numeric_columns = [fc.numeric_column(\"mother_age\"),\n",
    "                  fc.numeric_column(\"gestation_weeks\")]\n",
    "\n",
    "CATEGORIES = {\n",
    "    'plurality': [\"Single(1)\", \"Twins(2)\", \"Triplets(3)\",\n",
    "                  \"Quadruplets(4)\", \"Quintuplets(5)\", \"Multiple(2+)\"],\n",
    "    'is_male' : [\"True\", \"False\", \"Unknown\"],\n",
    "    'mother_race': [str(_) for _ in df.mother_race.unique()]\n",
    "}\n",
    "\n",
    "categorical_columns = []\n",
    "for feature, vocab in CATEGORIES.items():\n",
    "  cat_col = fc.categorical_column_with_vocabulary_list(\n",
    "        key=feature, vocabulary_list=vocab)\n",
    "  categorical_columns.append(fc.indicator_column(cat_col))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create our ensemble models\n",
    "\n",
    "We'll train three different neural network models. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /opt/conda/lib/python3.7/site-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4267: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /opt/conda/lib/python3.7/site-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4322: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n"
     ]
    }
   ],
   "source": [
    "inputs = {colname: tf.keras.layers.Input(\n",
    "    name=colname, shape=(), dtype=\"float32\")\n",
    "    for colname in [\"mother_age\", \"gestation_weeks\"]}\n",
    "\n",
    "inputs.update({colname: tf.keras.layers.Input(\n",
    "    name=colname, shape=(), dtype=\"string\")\n",
    "    for colname in [\"is_male\", \"plurality\", \"mother_race\"]})\n",
    "\n",
    "dnn_inputs = layers.DenseFeatures(categorical_columns+numeric_columns)(inputs)\n",
    "\n",
    "# model_1\n",
    "model1_h1 = layers.Dense(50, activation=\"relu\")(dnn_inputs)\n",
    "model1_h2 = layers.Dense(30, activation=\"relu\")(model1_h1)\n",
    "model1_output = layers.Dense(1, activation=\"relu\")(model1_h2)\n",
    "model_1 = tf.keras.models.Model(inputs=inputs, outputs=model1_output, name=\"model_1\")\n",
    "\n",
    "# model_2\n",
    "model2_h1 = layers.Dense(64, activation=\"relu\")(dnn_inputs)\n",
    "model2_h2 = layers.Dense(32, activation=\"relu\")(model2_h1)\n",
    "model2_output = layers.Dense(1, activation=\"relu\")(model2_h2)\n",
    "model_2 = tf.keras.models.Model(inputs=inputs, outputs=model2_output, name=\"model_2\")\n",
    "\n",
    "# model_3\n",
    "model3_h1 = layers.Dense(32, activation=\"relu\")(dnn_inputs)\n",
    "model3_output = layers.Dense(1, activation=\"relu\")(model3_h1)\n",
    "model_3 = tf.keras.models.Model(inputs=inputs, outputs=model3_output, name=\"model_3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function below trains a model and reports the MSE and RMSE on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fit model on dataset\n",
    "def fit_model(model):\n",
    "    # define model\n",
    "    model.compile(\n",
    "        loss=tf.keras.losses.MeanSquaredError(),\n",
    "        optimizer='adam', metrics=['mse'])\n",
    "    # fit model\n",
    "    model.fit(train_data.shuffle(500), epochs=1)\n",
    "    \n",
    "    # evaluate model\n",
    "    test_loss, test_mse = model.evaluate(test_data)\n",
    "    print('\\n\\n{}:\\nTest Loss {}, Test RMSE {}'.format(\n",
    "        model.name, test_loss, test_mse**0.5))\n",
    "    \n",
    "    return model"
   ]
  },
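  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "RMSE is the square root of MSE, which puts the error back in the units of the label (pounds here). The cell below is a small self-contained sketch on made-up labels and predictions; the helpers `mse` and `rmse` are hypothetical and are not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def mse(labels, predictions):\n",
    "    # Mean of the squared differences between labels and predictions.\n",
    "    squared_errors = [(y - p) ** 2 for y, p in zip(labels, predictions)]\n",
    "    return sum(squared_errors) / len(squared_errors)\n",
    "\n",
    "def rmse(labels, predictions):\n",
    "    # Square root of the MSE, in the same units as the labels.\n",
    "    return math.sqrt(mse(labels, predictions))\n",
    "\n",
    "# Made-up birth weights in pounds, for illustration only.\n",
    "labels = [7.5, 6.0, 8.0, 7.0]\n",
    "predictions = [7.0, 6.5, 7.0, 8.0]\n",
    "\n",
    "print(mse(labels, predictions), rmse(labels, predictions))"
   ]
  },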
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "directory already exists\n"
     ]
    }
   ],
   "source": [
    "# create directory for models\n",
    "try:\n",
    "    os.makedirs('models')\n",
    "except: \n",
    "    print(\"directory already exists\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll train each neural network and save the trained model to file. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17638/17638 [==============================] - 127s 7ms/step - loss: 1.1232 - mse: 1.1232\n",
      "   4343/Unknown - 24s 6ms/step - loss: 2.9008 - mse: 2.9011\n",
      "\n",
      "model_1:\n",
      "Test Loss 2.9008474301768974, Test RMSE 1.7032530850184173\n",
      "Saved models/model_1.h5\n",
      "\n",
      "17638/17638 [==============================] - 117s 7ms/step - loss: 1.1097 - mse: 1.1097\n",
      "   4343/Unknown - 23s 5ms/step - loss: 2.0815 - mse: 2.0817\n",
      "\n",
      "model_2:\n",
      "Test Loss 2.081478821759177, Test RMSE 1.4428068143465294\n",
      "Saved models/model_2.h5\n",
      "\n",
      "17638/17638 [==============================] - 114s 6ms/step - loss: 1.1293 - mse: 1.1293\n",
      "   4343/Unknown - 23s 5ms/step - loss: 2.4174 - mse: 2.4173\n",
      "\n",
      "model_3:\n",
      "Test Loss 2.417384987838966, Test RMSE 1.554769235887298\n",
      "Saved models/model_3.h5\n",
      "\n"
     ]
    }
   ],
   "source": [
    "members = [model_1, model_2, model_3]\n",
    "\n",
    "# fit and save models\n",
    "n_members = len(members)\n",
    "\n",
    "for i in range(n_members):\n",
    "    # fit model\n",
    "    model = fit_model(members[i])\n",
    "    # save model\n",
    "    filename = 'models/model_' + str(i + 1) + '.h5'\n",
    "    model.save(filename, save_format='tf')\n",
    "    print('Saved {}\\n'.format(filename))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The RMSE varies on each of the neural networks. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the trained models and create the stacked ensemble model. \n",
    "\n",
    "The function below loads the trained models and returns them in a list. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load trained models from file\n",
    "def load_models(n_models):\n",
    "    all_models = []\n",
    "    for i in range(n_models):\n",
    "        filename = 'models/model_' + str(i + 1) + '.h5'\n",
    "        # load model from file\n",
    "        model = models.load_model(filename)\n",
    "        # add to list of members\n",
    "        all_models.append(model)\n",
    "        print('>loaded %s' % filename)\n",
    "    return all_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">loaded models/model_1.h5\n",
      ">loaded models/model_2.h5\n",
      ">loaded models/model_3.h5\n",
      "Loaded 3 models\n"
     ]
    }
   ],
   "source": [
    "# load all models\n",
    "members = load_models(n_members)\n",
    "print('Loaded %d models' % len(members))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will need to freeze the layers of the pre-trained models since we won't train these models any further. The Stacked Ensemble will the trainable and learn how to best combine the results of the ensemble members. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# update all layers in all models to not be trainable\n",
    "for i in range(n_members):\n",
    "    model = members[i]\n",
    "    for layer in model.layers:\n",
    "        # make not trainable\n",
    "        layer.trainable = False\n",
    "        # rename to avoid 'unique layer name' issue\n",
    "        layer._name = 'ensemble_' + str(i+1) + '_' + layer.name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we'll create our Stacked Ensemble model. It is also a neural network. We'll use the Functional Keras API. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "member_inputs = [model.input for model in members]\n",
    "\n",
    "# concatenate merge output from each model\n",
    "member_outputs = [model.output for model in members]\n",
    "merge = layers.concatenate(member_outputs)\n",
    "h1 = layers.Dense(30, activation='relu')(merge)\n",
    "h2 = layers.Dense(20, activation='relu')(h1)\n",
    "h3 = layers.Dense(10, activation='relu')(h2)\n",
    "h4 = layers.Dense(5, activation='relu')(h2)\n",
    "ensemble_output = layers.Dense(1, activation='relu')(h3)\n",
    "ensemble_model = Model(inputs=member_inputs, outputs=ensemble_output)\n",
    "\n",
    "# plot graph of ensemble\n",
    "tf.keras.utils.plot_model(ensemble_model, show_shapes=True, to_file='ensemble_graph.png')\n",
    "\n",
    "# compile\n",
    "ensemble_model.compile(loss='mse', optimizer='adam', metrics=['mse'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to adapt our `tf.data` pipeline to accommodate the multiple inputs for our Stacked Ensemble model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "FEATURES = [\"is_male\", \"mother_age\", \"plurality\",\n",
    "            \"gestation_weeks\", \"mother_race\"]\n",
    "\n",
    "# stack input features for our tf.dataset\n",
    "def stack_features(features, label):\n",
    "    for feature in FEATURES:\n",
    "        for i in range(n_members):\n",
    "            features['ensemble_' + str(i+1) + '_' + feature] = features[feature]\n",
    "        \n",
    "    return features, label\n",
    "\n",
    "ensemble_data = train_data.map(stack_features).repeat(1)"
   ]
  },
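  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what this mapping does, the cell below applies the same renaming to a plain Python dictionary instead of a batch of tensors. It is only a sketch: the helper `stack_feature_dict` and the sample values are hypothetical, and unlike `stack_features` it returns a copy rather than modifying its argument. The key pattern `ensemble_<i>_<feature>` matches the layer renaming applied when the member models were frozen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FEATURE_NAMES = [\"is_male\", \"mother_age\", \"plurality\",\n",
    "                 \"gestation_weeks\", \"mother_race\"]\n",
    "\n",
    "def stack_feature_dict(features, n_members):\n",
    "    # Copy so the caller's dictionary is left untouched.\n",
    "    stacked = dict(features)\n",
    "    for name in FEATURE_NAMES:\n",
    "        for i in range(n_members):\n",
    "            # One extra key per (member, feature) pair, all sharing the same value.\n",
    "            stacked[\"ensemble_{}_{}\".format(i + 1, name)] = features[name]\n",
    "    return stacked\n",
    "\n",
    "# Made-up single example, for illustration only.\n",
    "example = {\"is_male\": \"True\", \"mother_age\": 26.0, \"plurality\": \"Single(1)\",\n",
    "           \"gestation_weeks\": 39.0, \"mother_race\": \"1.0\"}\n",
    "\n",
    "stacked = stack_feature_dict(example, n_members=3)\n",
    "print(sorted(stacked))"
   ]
  },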
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17638/17638 [==============================] - 165s 9ms/step - loss: 1.2281 - mse: 1.2281\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f89d832e710>"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ensemble_model.fit(ensemble_data.shuffle(500), epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we will evaluate our Stacked Ensemble against the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   4343/Unknown - 35s 8ms/step - loss: 2.0102 - mse: 2.0102- 35s 8ms/step - loss: 2.0125 -"
     ]
    }
   ],
   "source": [
    "val_loss, val_mse = ensemble_model.evaluate(test_data.map(stack_features))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation RMSE: 1.4178322129942251\n"
     ]
    }
   ],
   "source": [
    "print(\"Validation RMSE: {}\".format(val_mse**0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
